
Implement Hessian for sparse softmax cross entropy #31700

Merged
merged 11 commits on Feb 10, 2020

Conversation

mknbv
Contributor

@mknbv mknbv commented Aug 16, 2019

This reapplies #22231 with an implementation of _IsZero in eager mode.

@tensorflow-bot tensorflow-bot bot added the size:M CL Change Size: Medium label Aug 16, 2019
@rthadur rthadur self-assigned this Aug 16, 2019
@rthadur rthadur added this to Assigned Reviewer in PR Queue via automation Aug 16, 2019
@rthadur rthadur added the comp:eager Eager related issues label Aug 16, 2019
@mellanox-github

Can one of the admins verify this patch?

@allenlavoie
Member

Using @tf.custom_gradient was actually Alex's original suggestion, so that we don't run into the need to check for zeros. It's a bit complicated to do at the moment; I have a CL out for review which will hopefully make defining second-order gradients with custom_gradient easier. But something like this passes the Hessian test you've written:

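# Note: this sketch uses TF-internal helpers (custom_gradient, gen_nn_ops,
# _BroadcastMul, softmax, array_ops, math_ops) and is meant to live next to
# the existing cross-entropy gradient code rather than in user code.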
def _sparse_softmax_cross_entropy_with_logits_and_gradients(
    logits, labels, name):
  @custom_gradient.custom_gradient
  def _zeroth_order(unused_logits):
    loss, softmax_grad = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
        logits, labels, name=name)
    def _first_order_wrapper(dloss):
      """Swap dloss for logits so we can override a second-order gradient."""
      @custom_gradient.custom_gradient
      def _first_order(unused_logits):
        grad = _BroadcastMul(dloss, softmax_grad)
        def _second_order(ddloss):
          softmax_logits = softmax(logits)
          return ((ddloss - array_ops.squeeze(
              math_ops.matmul(
                  array_ops.expand_dims(ddloss, 1),
                  array_ops.expand_dims(softmax_logits, 2)),
              axis=1)) * softmax_logits)
        return grad, _second_order
      return _first_order(logits)
    return loss, _first_order_wrapper
  return _zeroth_order(logits)

That'd be a helper called instead of gen_nn_ops.sparse_softmax_cross_entropy_with_logits. I think we can do something similar for non-sparse softmax_cross_entropy_with_logits if we split up its labels and logits into separate custom_gradients (since there we need gradients for labels too).
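
For illustration, a hypothetical sketch of that call-site swap inside the public Python wrapper (the surrounding wrapper code is assumed, not shown):

# Before (the wrapper calls the generated op directly):
#   cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
#       logits, labels, name=name)
# After:
cost = _sparse_softmax_cross_entropy_with_logits_and_gradients(
    logits, labels, name=name)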

The trick for second-order gradients with nested @tf.custom_gradient is that the inner custom_gradient needs to have the original/primal input as its argument so its gradient can return a second-order gradient for that primal rather than for dloss.

Any objections to this approach rather than the Tensor-tagging approach? Happy to clean the example up or explain further if it's helpful.

@rthadur rthadur added the stat:awaiting response Status - Awaiting response from author label Sep 16, 2019
@rthadur
Contributor

rthadur commented Sep 16, 2019

@michaelkonobeev gentle ping to check the latest comments, thank you

@mknbv
Contributor Author

mknbv commented Sep 17, 2019

I will work on it soon.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Sep 18, 2019
@mknbv
Contributor Author

mknbv commented Oct 1, 2019

@allenlavoie has the CL that makes defining second-order gradients with custom_gradient easier been submitted? If not, I could implement it as in the example you provided. Also, are there any ideas about why the previous PR broke the convergence of the NCF Keras model with run_eagerly=True? Maybe I could add a test for it? The Hessian computation itself seems correct; I have checked it multiple times.

@alextp
Contributor

alextp commented Oct 1, 2019 via email

@mknbv
Contributor Author

mknbv commented Oct 11, 2019

@allenlavoie in the example you provided, why do we call _BroadcastMul(dloss, softmax_grad) inside _first_order rather than inside _first_order_wrapper? Wouldn't this lead to part of the computation being skipped when computing second-order derivatives, as in the following example:

import tensorflow as tf

@tf.custom_gradient
def f(x):
  y, grad_x = x ** 3, 3 * x ** 2
  
  def first_order_grad(dy):
    @tf.custom_gradient
    def first_order_custom(unused_x):
      def second_order_grad(ddy):
        return 6 * x * ddy
      return grad_x * dy, second_order_grad
    return first_order_custom(x)
  
  return y, first_order_grad

x = tf.Variable([1.])

with tf.GradientTape(persistent=True) as tape:
  y = tf.square(f(x))
  dx = tape.gradient(y, x)

print(dx) # 6, expected 6
ddx = tape.gradient(dx, x)
print(ddx) # 6, but expected 30

If the multiplication by dy is moved one line below (outside the inner custom_gradient), it works as expected.
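
Concretely, with that fix the first-order gradient function from the example above becomes (a sketch, multiplying by dy outside the inner custom_gradient):

def first_order_grad(dy):
  @tf.custom_gradient
  def first_order_custom(unused_x):
    def second_order_grad(ddy):
      return 6 * x * ddy
    # Return the raw first-order gradient; dy is applied outside.
    return grad_x, second_order_grad
  # Multiplying by dy out here keeps dy's own dependence on x visible to the
  # tape, so the second tape.gradient call returns the expected 30.
  return first_order_custom(x) * dy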

@allenlavoie
Member

You're right, it needs to be outside the inner custom_gradient. And I need to fix that for primals=, which will be somewhat tricky.

@ebrevdo
Contributor

ebrevdo commented Oct 14, 2019

A second issue that will stop us from being able to solve this problem with tf.custom_gradient is that custom_gradient is not serialized with the graph. If you serialize a SavedModel containing a cross-entropy loss that uses tf.custom_gradient, then after deserialization you can get either errors or incorrect behavior when taking gradients of the resulting tensor. So there are two major issues here.

@allenlavoie
Member

I don't think the SavedModel issue needs to block this. We will need a workaround, possibly some indication that this custom_gradient is safe to ignore (meaning we'll still only be able to take first-order gradients from SavedModels).

@mknbv
Contributor Author

mknbv commented Oct 18, 2019

custom_gradient with primals set should implement a vector-Jacobian product, right? Assuming this works in the general case, the downside is that it will prevent some optimizations, like the one in the cross-entropy gradient function, which computes only the product between the dependent parts of the gradient and the Jacobian. I could implement it using the variant with a wrapper function to keep this optimization.
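
For context, the optimization in question: the per-example gradient of sparse softmax cross entropy with respect to the logits is softmax(logits) - one_hot(labels), so the backward pass only needs a broadcast multiply with the incoming dloss (which is what _BroadcastMul(dloss, softmax_grad) does above) rather than a full Jacobian-vector product. A minimal sketch in plain TF ops, with an illustrative function name:

import tensorflow as tf

def sparse_xent_backprop(logits, labels, dloss):
  # Per-example gradient of the loss w.r.t. logits: softmax - one_hot.
  softmax = tf.nn.softmax(logits)
  grad = softmax - tf.one_hot(labels, tf.shape(logits)[-1], dtype=logits.dtype)
  # dloss has shape [batch]; broadcasting it over the class dimension avoids
  # ever materializing the full [batch, classes, classes] Jacobian.
  return tf.expand_dims(dloss, -1) * grad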

@allenlavoie could you point me to a way of ignoring custom_gradient during serialization, or are there other workarounds?

@allenlavoie
Member

Yeah, the primals= argument has been removed; I don't see a way to support that API (and it was never in a stable release). I've updated the example in the custom_gradient docstring with your suggested fix.

The easiest way to ignore these custom_gradients in SavedModels is probably to set an attribute. I'd make an internal-only version with the extra argument, then add an attribute to this node. Then you can delete the _gradient_op_type attribute here.

@gbaned
Contributor

gbaned commented Nov 26, 2019

@michaelkonobeev Could you please check the reviewer comments and keep us posted? Thanks!

@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label Nov 26, 2019
@mknbv
Contributor Author

mknbv commented Nov 30, 2019

I implemented Hessian computation through tf.custom_gradient but still need to implement the workaround for SavedModels. Hope to work on this next week.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Dec 1, 2019
@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label Dec 5, 2019
@mknbv
Contributor Author

mknbv commented Dec 10, 2019

I noticed that when the SparseSoftmaxCrossEntropyWithLogits gradient registration is removed from nn_grad.py, using tf.custom_gradient and then wrapping the function in tf.function causes an exception when working with a persistent tape. Consider the following code:

import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

@tf.custom_gradient
def xent(logits, labels):
  loss, grad = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
      logits, labels)
  def grad_fn(dy):
    return tf.expand_dims(dy, -1) * grad, None
  return loss, grad_fn

@tf.function
def function(logits, labels):
  return xent(logits, labels)

logits = tf.Variable([[1., 2.]])
labels = tf.Variable([1])
with tf.GradientTape(persistent=True) as tape:
  loss = function(logits, labels)

grad = tape.gradient(loss, logits)

This leads to a LookupError from here. If the tape is not persistent, or if xent is called directly, it works fine. Is this expected behavior? It seems necessary to support this case before moving on to SavedModels, since they work with tf.functions.

@rthadur
Contributor

rthadur commented Jan 23, 2020

@michaelkonobeev can you please resolve conflicts?

PR Queue automation moved this from Approved by Reviewer to Reviewer Requested Changes Jan 24, 2020
@tensorflow-bot tensorflow-bot bot removed the ready to pull PR ready for merge process label Jan 24, 2020
mihaimaruseac previously approved these changes Jan 24, 2020
PR Queue automation moved this from Reviewer Requested Changes to Approved by Reviewer Jan 24, 2020
@tensorflow-bot tensorflow-bot bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Jan 24, 2020
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Jan 24, 2020
@gbaned
Contributor

gbaned commented Feb 5, 2020

@michaelkonobeev Can you please resolve conflicts? Thanks!

@gbaned gbaned added stat:awaiting response Status - Awaiting response from author and removed ready to pull PR ready for merge process labels Feb 5, 2020
PR Queue automation moved this from Approved by Reviewer to Reviewer Requested Changes Feb 7, 2020
PR Queue automation moved this from Reviewer Requested Changes to Approved by Reviewer Feb 7, 2020
@tensorflow-bot tensorflow-bot bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Feb 7, 2020
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Feb 7, 2020
@gbaned gbaned removed the stat:awaiting response Status - Awaiting response from author label Feb 10, 2020
tensorflow-copybara pushed a commit that referenced this pull request Feb 10, 2020
PiperOrigin-RevId: 294219774
Change-Id: I6a5324599b192a080fc78ce715b28107fabbc236
@tensorflow-copybara tensorflow-copybara merged commit 24418e0 into tensorflow:master Feb 10, 2020
PR Queue automation moved this from Approved by Reviewer to Merged Feb 10, 2020
Labels
cla: yes · comp:eager Eager related issues · ready to pull PR ready for merge process · size:M CL Change Size: Medium